@TIFitis TIFitis commented Sep 15, 2025

This PR adds a new complex.powi operation to MLIR's complex dialect for computing complex numbers raised to integer powers.

Key changes include:

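  • Addition of the new PowiOp operation definition to the Complex dialect
  • Integration of complex.powi into the Math dialect's AlgebraicSimplification pass, so integer powers can be strength-reduced
  • Support in the ComplexToROCDLLibraryCalls conversion, which rewrites complex.powi to complex.pow with a real-valued complex exponent
  • Updates to the Flang frontend so that it generates the new operation for integer exponents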

This depends on #158642.
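
To make the intended meaning of the op concrete, here is a minimal stand-alone sketch of "complex number raised to an integer power" using `std::complex` and binary exponentiation. The helper name `complexPowi` is hypothetical and not part of the PR. Treating a negative exponent as the reciprocal of the positive power is an assumption based on how Fortran's `**` behaves, since the op description does not spell this out.

```cpp
#include <complex>
#include <cstdint>

// Hypothetical helper, not taken from the PR: computes z**n for an integer n
// by repeated squaring, so it needs O(log |n|) complex multiplications.
template <typename T>
std::complex<T> complexPowi(std::complex<T> z, std::int64_t n) {
  // Take the magnitude in unsigned arithmetic so INT64_MIN does not overflow.
  std::uint64_t k = n < 0 ? 0 - static_cast<std::uint64_t>(n)
                          : static_cast<std::uint64_t>(n);
  std::complex<T> result(T(1), T(0));
  while (k != 0) {
    if (k & 1)
      result *= z; // this bit of the exponent contributes the current power
    z *= z;        // square the running power for the next bit
    k >>= 1;
  }
  // Assumption: a negative exponent means the reciprocal of the positive power.
  return n < 0 ? std::complex<T>(T(1), T(0)) / result : result;
}
```

For example, `complexPowi(std::complex<double>(0.0, 1.0), 2)` evaluates i squared, and an exponent of 0 yields 1 for any base.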

@llvmbot added the labels mlir, flang (Flang issues not falling into any other category), mlir:complex (MLIR complex dialect), mlir:math, and flang:fir-hlfir on Sep 15, 2025

llvmbot commented Sep 15, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-math

@llvm/pr-subscribers-flang-fir-hlfir

Author: Akash Banerjee (TIFitis)

Changes

Add a new powi op to the complex dialect that takes a complex number and an integer exponent.
Add lowering changes that make use of the new op, along with conversion pass support.
Also add the complex.powi op to the AlgebraicSimplification pass for simple optimisations.
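
As a rough illustration of the lowering, the Flang conversion pass in this patch picks a runtime entry point from the bit widths of the complex element type and the integer exponent. The stand-alone sketch below mirrors that dispatch. The function `powiRuntimeName` is hypothetical, and the `_FortranA`-prefixed symbol names are the ones that appear in the patch's test expectations.

```cpp
#include <optional>
#include <string>

// Hypothetical helper, not part of the PR: maps (float width, integer width)
// to the runtime symbol named in the patch's tests. Unsupported combinations
// return std::nullopt, matching the pass leaving such ops untouched.
std::optional<std::string> powiRuntimeName(unsigned realBits,
                                           unsigned intBits) {
  if (realBits == 32 && intBits == 32)
    return "_FortranAcpowi";
  if (realBits == 32 && intBits == 64)
    return "_FortranAcpowk";
  if (realBits == 64 && intBits == 32)
    return "_FortranAzpowi";
  if (realBits == 64 && intBits == 64)
    return "_FortranAzpowk";
  if (realBits == 128 && intBits == 32)
    return "_FortranAcqpowi";
  if (realBits == 128 && intBits == 64)
    return "_FortranAcqpowk";
  return std::nullopt;
}
```

A caller would check the optional before emitting a call. For instance, `powiRuntimeName(16, 32)` is empty, so a half-precision op would be left as is.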


Patch is 24.12 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/158722.diff

14 Files Affected:

  • (modified) flang/lib/Optimizer/Builder/IntrinsicCall.cpp (+13-7)
  • (modified) flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp (+42-50)
  • (modified) flang/test/Lower/HLFIR/binary-ops.f90 (+1-1)
  • (modified) flang/test/Lower/Intrinsics/pow_complex16i.f90 (+1-1)
  • (modified) flang/test/Lower/Intrinsics/pow_complex16k.f90 (+1-1)
  • (modified) flang/test/Lower/amdgcn-complex.f90 (+9)
  • (modified) flang/test/Lower/power-operator.f90 (+4-5)
  • (modified) flang/test/Transforms/convert-complex-pow.fir (+12-30)
  • (modified) mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td (+26)
  • (modified) mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp (+37-4)
  • (modified) mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp (+17-7)
  • (modified) mlir/lib/Dialect/Math/Transforms/CMakeLists.txt (+1)
  • (modified) mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir (+14)
  • (added) mlir/test/Dialect/Complex/powi-simplify.mlir (+20)
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 466458c05dba7..74a4e8f85c8ff 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -1331,14 +1331,20 @@ mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc,
     return genLibCall(builder, loc, mathOp, mathLibFuncType, args);
   auto complexTy = mlir::cast<mlir::ComplexType>(mathLibFuncType.getInput(0));
   mlir::Value exp = args[1];
-  if (!mlir::isa<mlir::ComplexType>(exp.getType())) {
-    auto realTy = complexTy.getElementType();
-    mlir::Value realExp = builder.createConvert(loc, realTy, exp);
-    mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
-    exp =
-        builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero);
+  mlir::Value result;
+  if (mlir::isa<mlir::IntegerType>(exp.getType()) ||
+      mlir::isa<mlir::IndexType>(exp.getType())) {
+    result = builder.create<mlir::complex::PowiOp>(loc, args[0], exp);
+  } else {
+    if (!mlir::isa<mlir::ComplexType>(exp.getType())) {
+      auto realTy = complexTy.getElementType();
+      mlir::Value realExp = builder.createConvert(loc, realTy, exp);
+      mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
+      exp = builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp,
+                                                    zero);
+    }
+    result = builder.create<mlir::complex::PowOp>(loc, args[0], exp);
   }
-  mlir::Value result = builder.create<mlir::complex::PowOp>(loc, args[0], exp);
   result = builder.createConvert(loc, mathLibFuncType.getResult(0), result);
   return result;
 }
diff --git a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
index dced5f90d6924..42f5df160798c 100644
--- a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
+++ b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
@@ -61,63 +61,55 @@ void ConvertComplexPowPass::runOnOperation() {
 
   fir::FirOpBuilder builder(mod, fir::getKindMapping(mod));
 
-  mod.walk([&](complex::PowOp op) {
+  mod.walk([&](complex::PowiOp op) {
     builder.setInsertionPoint(op);
     Location loc = op.getLoc();
     auto complexTy = cast<ComplexType>(op.getType());
     auto elemTy = complexTy.getElementType();
-
     Value base = op.getLhs();
-    Value rhs = op.getRhs();
-
-    Value intExp;
-    if (auto create = rhs.getDefiningOp<complex::CreateOp>()) {
-      if (isZero(create.getImaginary())) {
-        if (auto conv = create.getReal().getDefiningOp<fir::ConvertOp>()) {
-          if (auto intTy = dyn_cast<IntegerType>(conv.getValue().getType()))
-            intExp = conv.getValue();
-        }
-      }
-    }
-
+    Value intExp = op.getRhs();
     func::FuncOp callee;
-    SmallVector<Value> args;
-    if (intExp) {
-      unsigned realBits = cast<FloatType>(elemTy).getWidth();
-      unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth();
-      auto funcTy = builder.getFunctionType(
-          {complexTy, builder.getIntegerType(intBits)}, {complexTy});
-      if (realBits == 32 && intBits == 32)
-        callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowi), funcTy);
-      else if (realBits == 32 && intBits == 64)
-        callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowk), funcTy);
-      else if (realBits == 64 && intBits == 32)
-        callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowi), funcTy);
-      else if (realBits == 64 && intBits == 64)
-        callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowk), funcTy);
-      else if (realBits == 128 && intBits == 32)
-        callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowi), funcTy);
-      else if (realBits == 128 && intBits == 64)
-        callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy);
-      else
-        return;
-      args = {base, intExp};
-    } else {
-      unsigned realBits = cast<FloatType>(elemTy).getWidth();
-      auto funcTy =
-          builder.getFunctionType({complexTy, complexTy}, {complexTy});
-      if (realBits == 32)
-        callee = getOrDeclare(builder, loc, "cpowf", funcTy);
-      else if (realBits == 64)
-        callee = getOrDeclare(builder, loc, "cpow", funcTy);
-      else if (realBits == 128)
-        callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy);
-      else
-        return;
-      args = {base, rhs};
-    }
+    unsigned realBits = cast<FloatType>(elemTy).getWidth();
+    unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth();
+    auto funcTy = builder.getFunctionType(
+        {complexTy, builder.getIntegerType(intBits)}, {complexTy});
+    if (realBits == 32 && intBits == 32)
+      callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowi), funcTy);
+    else if (realBits == 32 && intBits == 64)
+      callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowk), funcTy);
+    else if (realBits == 64 && intBits == 32)
+      callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowi), funcTy);
+    else if (realBits == 64 && intBits == 64)
+      callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowk), funcTy);
+    else if (realBits == 128 && intBits == 32)
+      callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowi), funcTy);
+    else if (realBits == 128 && intBits == 64)
+      callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy);
+    else
+      return;
+    auto call = fir::CallOp::create(builder, loc, callee, {base, intExp});
+    op.replaceAllUsesWith(call.getResult(0));
+    op.erase();
+  });
 
-    auto call = fir::CallOp::create(builder, loc, callee, args);
+  mod.walk([&](complex::PowOp op) {
+    builder.setInsertionPoint(op);
+    Location loc = op.getLoc();
+    auto complexTy = cast<ComplexType>(op.getType());
+    auto elemTy = complexTy.getElementType();
+    unsigned realBits = cast<FloatType>(elemTy).getWidth();
+    func::FuncOp callee;
+    auto funcTy = builder.getFunctionType({complexTy, complexTy}, {complexTy});
+    if (realBits == 32)
+      callee = getOrDeclare(builder, loc, "cpowf", funcTy);
+    else if (realBits == 64)
+      callee = getOrDeclare(builder, loc, "cpow", funcTy);
+    else if (realBits == 128)
+      callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy);
+    else
+      return;
+    auto call =
+        fir::CallOp::create(builder, loc, callee, {op.getLhs(), op.getRhs()});
     op.replaceAllUsesWith(call.getResult(0));
     op.erase();
   });
diff --git a/flang/test/Lower/HLFIR/binary-ops.f90 b/flang/test/Lower/HLFIR/binary-ops.f90
index 1fbd333db37c3..7e1691dd1587a 100644
--- a/flang/test/Lower/HLFIR/binary-ops.f90
+++ b/flang/test/Lower/HLFIR/binary-ops.f90
@@ -193,7 +193,7 @@ subroutine complex_to_int_power(x, y, z)
 ! CHECK:  %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
 ! CHECK:  %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<complex<f32>>
 ! CHECK:  %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<i32>
-! CHECK:  %[[VAL_8:.*]] = complex.pow
+! CHECK:  %[[VAL_8:.*]] = complex.powi %[[VAL_6]], %[[VAL_7]] : complex<f32>, i32
 
 subroutine extremum(c, n, l)
   integer(8), intent(in) :: l
diff --git a/flang/test/Lower/Intrinsics/pow_complex16i.f90 b/flang/test/Lower/Intrinsics/pow_complex16i.f90
index 1827863a57f43..0b26024b02021 100644
--- a/flang/test/Lower/Intrinsics/pow_complex16i.f90
+++ b/flang/test/Lower/Intrinsics/pow_complex16i.f90
@@ -4,7 +4,7 @@
 ! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
 
 ! PRECISE: fir.call @_FortranAcqpowi({{.*}}){{.*}}: (complex<f128>, i32) -> complex<f128>
-! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
+! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f128>
   complex(16) :: a
   integer(4) :: b
   b = a ** b
diff --git a/flang/test/Lower/Intrinsics/pow_complex16k.f90 b/flang/test/Lower/Intrinsics/pow_complex16k.f90
index 039dfd5152a06..90a9f5e03628d 100644
--- a/flang/test/Lower/Intrinsics/pow_complex16k.f90
+++ b/flang/test/Lower/Intrinsics/pow_complex16k.f90
@@ -4,7 +4,7 @@
 ! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
 
 ! PRECISE: fir.call @_FortranAcqpowk({{.*}}){{.*}}: (complex<f128>, i64) -> complex<f128>
-! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
+! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f128>
   complex(16) :: a
   integer(8) :: b
   b = a ** b
diff --git a/flang/test/Lower/amdgcn-complex.f90 b/flang/test/Lower/amdgcn-complex.f90
index 4ee5de4d2842e..a28eaea82379b 100644
--- a/flang/test/Lower/amdgcn-complex.f90
+++ b/flang/test/Lower/amdgcn-complex.f90
@@ -25,3 +25,12 @@ subroutine pow_test(a, b, c)
    complex :: a, b, c
    a = b**c
 end subroutine pow_test
+
+! CHECK-LABEL: func @_QPpowi_test(
+! CHECK: complex.powi
+! CHECK-NOT: fir.call @_FortranAcpowi
+subroutine powi_test(a, b, c)
+   complex :: a, b
+   integer :: i
+   b = a ** i
+end subroutine powi_test
diff --git a/flang/test/Lower/power-operator.f90 b/flang/test/Lower/power-operator.f90
index 3058927144248..9f74d172a6bb2 100644
--- a/flang/test/Lower/power-operator.f90
+++ b/flang/test/Lower/power-operator.f90
@@ -96,7 +96,7 @@ subroutine pow_c4_i4(x, y, z)
   complex :: x, z
   integer :: y
   z = x ** y
-  ! CHECK: complex.pow
+  ! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f32>, i32
   ! PRECISE: fir.call @_FortranAcpowi
 end subroutine
 
@@ -105,7 +105,7 @@ subroutine pow_c4_i8(x, y, z)
   complex :: x, z
   integer(8) :: y
   z = x ** y
-  ! CHECK: complex.pow
+  ! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f32>, i64
   ! PRECISE: fir.call @_FortranAcpowk
 end subroutine
 
@@ -114,7 +114,7 @@ subroutine pow_c8_i4(x, y, z)
   complex(8) :: x, z
   integer :: y
   z = x ** y
-  ! CHECK: complex.pow
+  ! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f64>, i32
   ! PRECISE: fir.call @_FortranAzpowi
 end subroutine
 
@@ -123,7 +123,7 @@ subroutine pow_c8_i8(x, y, z)
   complex(8) :: x, z
   integer(8) :: y
   z = x ** y
-  ! CHECK: complex.pow
+  ! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f64>, i64
   ! PRECISE: fir.call @_FortranAzpowk
 end subroutine
 
@@ -142,4 +142,3 @@ subroutine pow_c8_c8(x, y, z)
   ! CHECK: complex.pow %{{.*}}, %{{.*}} : complex<f64>
   ! PRECISE: fir.call @cpow
 end subroutine
-
diff --git a/flang/test/Transforms/convert-complex-pow.fir b/flang/test/Transforms/convert-complex-pow.fir
index d980817aba9b9..4555fea61e496 100644
--- a/flang/test/Transforms/convert-complex-pow.fir
+++ b/flang/test/Transforms/convert-complex-pow.fir
@@ -2,18 +2,12 @@
 
 module {
   func.func @pow_c4_i4(%arg0: complex<f32>, %arg1: i32) -> complex<f32> {
-    %c0 = arith.constant 0.000000e+00 : f32
-    %c1 = fir.convert %arg1 : (i32) -> f32
-    %c2 = complex.create %c1, %c0 : complex<f32>
-    %0 = complex.pow %arg0, %c2 : complex<f32>
+    %0 = complex.powi %arg0, %arg1 : complex<f32>, i32
     return %0 : complex<f32>
   }
 
   func.func @pow_c4_i8(%arg0: complex<f32>, %arg1: i64) -> complex<f32> {
-    %c0 = arith.constant 0.000000e+00 : f32
-    %c1 = fir.convert %arg1 : (i64) -> f32
-    %c2 = complex.create %c1, %c0 : complex<f32>
-    %0 = complex.pow %arg0, %c2 : complex<f32>
+    %0 = complex.powi %arg0, %arg1 : complex<f32>, i64
     return %0 : complex<f32>
   }
 
@@ -23,18 +17,12 @@ module {
   }
 
   func.func @pow_c8_i4(%arg0: complex<f64>, %arg1: i32) -> complex<f64> {
-    %c0 = arith.constant 0.000000e+00 : f64
-    %c1 = fir.convert %arg1 : (i32) -> f64
-    %c2 = complex.create %c1, %c0 : complex<f64>
-    %0 = complex.pow %arg0, %c2 : complex<f64>
+    %0 = complex.powi %arg0, %arg1 : complex<f64>, i32
     return %0 : complex<f64>
   }
 
   func.func @pow_c8_i8(%arg0: complex<f64>, %arg1: i64) -> complex<f64> {
-    %c0 = arith.constant 0.000000e+00 : f64
-    %c1 = fir.convert %arg1 : (i64) -> f64
-    %c2 = complex.create %c1, %c0 : complex<f64>
-    %0 = complex.pow %arg0, %c2 : complex<f64>
+    %0 = complex.powi %arg0, %arg1 : complex<f64>, i64
     return %0 : complex<f64>
   }
 
@@ -44,18 +32,12 @@ module {
   }
 
   func.func @pow_c16_i4(%arg0: complex<f128>, %arg1: i32) -> complex<f128> {
-    %c0 = arith.constant 0.000000e+00 : f128
-    %c1 = fir.convert %arg1 : (i32) -> f128
-    %c2 = complex.create %c1, %c0 : complex<f128>
-    %0 = complex.pow %arg0, %c2 : complex<f128>
+    %0 = complex.powi %arg0, %arg1 : complex<f128>, i32
     return %0 : complex<f128>
   }
 
   func.func @pow_c16_i8(%arg0: complex<f128>, %arg1: i64) -> complex<f128> {
-    %c0 = arith.constant 0.000000e+00 : f128
-    %c1 = fir.convert %arg1 : (i64) -> f128
-    %c2 = complex.create %c1, %c0 : complex<f128>
-    %0 = complex.pow %arg0, %c2 : complex<f128>
+    %0 = complex.powi %arg0, %arg1 : complex<f128>, i64
     return %0 : complex<f128>
   }
 
@@ -67,11 +49,11 @@ module {
 
 // CHECK-LABEL: func.func @pow_c4_i4(
 // CHECK: fir.call @_FortranAcpowi(%{{.*}}, %{{.*}}) : (complex<f32>, i32) -> complex<f32>
-// CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
 
 // CHECK-LABEL: func.func @pow_c4_i8(
 // CHECK: fir.call @_FortranAcpowk(%{{.*}}, %{{.*}}) : (complex<f32>, i64) -> complex<f32>
-// CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
 
 // CHECK-LABEL: func.func @pow_c4_c4(
 // CHECK: fir.call @cpowf(%{{.*}}, %{{.*}}) : (complex<f32>, complex<f32>) -> complex<f32>
@@ -79,11 +61,11 @@ module {
 
 // CHECK-LABEL: func.func @pow_c8_i4(
 // CHECK: fir.call @_FortranAzpowi(%{{.*}}, %{{.*}}) : (complex<f64>, i32) -> complex<f64>
-// CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
 
 // CHECK-LABEL: func.func @pow_c8_i8(
 // CHECK: fir.call @_FortranAzpowk(%{{.*}}, %{{.*}}) : (complex<f64>, i64) -> complex<f64>
-// CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
 
 // CHECK-LABEL: func.func @pow_c8_c8(
 // CHECK: fir.call @cpow(%{{.*}}, %{{.*}}) : (complex<f64>, complex<f64>) -> complex<f64>
@@ -91,11 +73,11 @@ module {
 
 // CHECK-LABEL: func.func @pow_c16_i4(
 // CHECK: fir.call @_FortranAcqpowi(%{{.*}}, %{{.*}}) : (complex<f128>, i32) -> complex<f128>
-// CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
 
 // CHECK-LABEL: func.func @pow_c16_i8(
 // CHECK: fir.call @_FortranAcqpowk(%{{.*}}, %{{.*}}) : (complex<f128>, i64) -> complex<f128>
-// CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
 
 // CHECK-LABEL: func.func @pow_c16_c16(
 // CHECK: fir.call @_FortranACPowF128(%{{.*}}, %{{.*}}) : (complex<f128>, complex<f128>) -> complex<f128>
diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
index 44590406301eb..ca5103c16889c 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -443,6 +443,32 @@ def PowOp : ComplexArithmeticOp<"pow"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// PowiOp
+//===----------------------------------------------------------------------===//
+
+def PowiOp : Complex_Op<"powi",
+    [Pure, Elementwise, SameOperandsAndResultShape,
+     AllTypesMatch<["lhs", "result"]>]> {
+  let summary = "complex number raised to integer power";
+  let description = [{
+    The `powi` operation takes a complex number and an integer exponent.
+
+    Example:
+
+    ```mlir
+    %a = complex.powi %b, %c : complex<f32>, i32
+    ```
+  }];
+
+  let arguments = (ins Complex<AnyFloat>:$lhs,
+                       AnySignlessInteger:$rhs);
+  let results = (outs Complex<AnyFloat>:$result);
+
+  let assemblyFormat =
+      "$lhs `,` $rhs attr-dict `:` type($result) `,` type($rhs)";
+}
+
 //===----------------------------------------------------------------------===//
 // ReOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
index 0372f32d6b6df..25e5ab49cdb8c 100644
--- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -7,9 +7,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
@@ -71,10 +73,40 @@ struct PowOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowOp> {
     return success();
   }
 };
+
+// Rewrite complex.powi(z, n) -> complex.pow(z, complex(float(n), 0))
+struct PowiOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowiOp> {
+  using OpRewritePattern<complex::PowiOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(complex::PowiOp op,
+                                PatternRewriter &rewriter) const final {
+    auto complexType = cast<ComplexType>(getElementTypeOrSelf(op.getType()));
+    Type elementType = complexType.getElementType();
+
+    Type exponentType = op.getRhs().getType();
+    Type exponentFloatType = elementType;
+    if (auto shapedType = dyn_cast<ShapedType>(exponentType))
+      exponentFloatType = shapedType.cloneWith(std::nullopt, elementType);
+
+    Location loc = op.getLoc();
+    Value exponentReal =
+        rewriter.create<arith::SIToFPOp>(loc, exponentFloatType, op.getRhs());
+    Value zeroImag = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(exponentFloatType));
+    Value exponent = rewriter.create<complex::CreateOp>(
+        loc, op.getLhs().getType(), exponentReal, zeroImag);
+
+    rewriter
+        .replaceOpWithNewOp<complex::PowOp>(op, op.getType(), op.getLhs(),
+                                            exponent);
+    return success();
+  }
+};
 } // namespace
 
 void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
     RewritePatternSet &patterns) {
+  patterns.add<PowiOpToROCDLLibraryCalls>(patterns.getContext());
   patterns.add<PowOpToROCDLLibraryCalls>(patterns.getContext());
   patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float32Type>>(
       patterns.getContext(), "__ocml_cabs_f32");
@@ -125,11 +157,12 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {
   populateComplexToROCDLLibraryCallsConversionPatterns(patterns);
 
   ConversionTarget target(getContext());
-  target.addLegalDialect<func::FuncDialect>();
-  target.addLegalOp<complex::MulOp>();
+  target.addLegalDialect<arith::ArithDialect, func::FuncDialect>();
+  target.addLegalOp<complex::CreateOp, complex::MulOp>();
   target.addIllegalOp<complex::AbsOp, complex::CosOp, complex::ExpOp,
-                      complex::LogOp, complex::PowOp, complex::SinOp,
-                      complex::SqrtOp, complex::TanOp, complex::TanhOp>();
+                      complex::LogOp, complex::PowOp, complex::PowiOp,
+                      complex::SinOp, complex::SqrtOp, complex::TanOp,
+                      complex::TanhOp>();
   if (failed(applyPartialConversion(op, target, std::move(patterns))))
     signalPassFailure();
 }
diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
index 31785eb20a642..3711c112cc631 100644
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -13,6 +13,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Math/Transforms/Passes.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -175,12 +176,20 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
 
   Value one;
   Type opType = getElementTypeOrSelf(op.getType());
-  if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>)
+  if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>) {
     one = arith::ConstantOp::create(rewriter, loc,
                                     rewriter.getFloatAttr(opType, 1.0));
-  else
+  } else if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>) {
+    ...
[truncated]
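
The ROCDL conversion in the diff above rewrites `complex.powi(z, n)` into `complex.pow(z, complex(float(n), 0))`. The following stand-alone sketch shows the same idea with `std::complex`. The helper name `powiViaPow` is hypothetical, and `std::pow` merely stands in for the library call the pass eventually emits.

```cpp
#include <complex>
#include <cstdint>

// Hypothetical helper, not part of the PR: evaluates z**n by promoting the
// integer exponent to a complex value with zero imaginary part and then
// using the general complex power function.
template <typename T>
std::complex<T> powiViaPow(std::complex<T> z, std::int64_t n) {
  // Corresponds to the signed int-to-float conversion plus complex.create.
  std::complex<T> exponent(static_cast<T>(n), T(0));
  // Corresponds to the complex.pow that replaces complex.powi.
  return std::pow(z, exponent);
}
```

Because the general power goes through logarithms and exponentials, the result is only approximately equal to the one obtained by repeated multiplication, so comparisons should use a tolerance.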


llvmbot commented Sep 15, 2025

@llvm/pr-subscribers-mlir-complex

 
@@ -114,7 +114,7 @@ subroutine pow_c8_i4(x, y, z)
   complex(8) :: x, z
   integer :: y
   z = x ** y
-  ! CHECK: complex.pow
+  ! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f64>, i32
   ! PRECISE: fir.call @_FortranAzpowi
 end subroutine
 
@@ -123,7 +123,7 @@ subroutine pow_c8_i8(x, y, z)
   complex(8) :: x, z
   integer(8) :: y
   z = x ** y
-  ! CHECK: complex.pow
+  ! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f64>, i64
   ! PRECISE: fir.call @_FortranAzpowk
 end subroutine
 
@@ -142,4 +142,3 @@ subroutine pow_c8_c8(x, y, z)
   ! CHECK: complex.pow %{{.*}}, %{{.*}} : complex<f64>
   ! PRECISE: fir.call @cpow
 end subroutine
-
diff --git a/flang/test/Transforms/convert-complex-pow.fir b/flang/test/Transforms/convert-complex-pow.fir
index d980817aba9b9..4555fea61e496 100644
--- a/flang/test/Transforms/convert-complex-pow.fir
+++ b/flang/test/Transforms/convert-complex-pow.fir
@@ -2,18 +2,12 @@
 
 module {
   func.func @pow_c4_i4(%arg0: complex<f32>, %arg1: i32) -> complex<f32> {
-    %c0 = arith.constant 0.000000e+00 : f32
-    %c1 = fir.convert %arg1 : (i32) -> f32
-    %c2 = complex.create %c1, %c0 : complex<f32>
-    %0 = complex.pow %arg0, %c2 : complex<f32>
+    %0 = complex.powi %arg0, %arg1 : complex<f32>, i32
     return %0 : complex<f32>
   }
 
   func.func @pow_c4_i8(%arg0: complex<f32>, %arg1: i64) -> complex<f32> {
-    %c0 = arith.constant 0.000000e+00 : f32
-    %c1 = fir.convert %arg1 : (i64) -> f32
-    %c2 = complex.create %c1, %c0 : complex<f32>
-    %0 = complex.pow %arg0, %c2 : complex<f32>
+    %0 = complex.powi %arg0, %arg1 : complex<f32>, i64
     return %0 : complex<f32>
   }
 
@@ -23,18 +17,12 @@ module {
   }
 
   func.func @pow_c8_i4(%arg0: complex<f64>, %arg1: i32) -> complex<f64> {
-    %c0 = arith.constant 0.000000e+00 : f64
-    %c1 = fir.convert %arg1 : (i32) -> f64
-    %c2 = complex.create %c1, %c0 : complex<f64>
-    %0 = complex.pow %arg0, %c2 : complex<f64>
+    %0 = complex.powi %arg0, %arg1 : complex<f64>, i32
     return %0 : complex<f64>
   }
 
   func.func @pow_c8_i8(%arg0: complex<f64>, %arg1: i64) -> complex<f64> {
-    %c0 = arith.constant 0.000000e+00 : f64
-    %c1 = fir.convert %arg1 : (i64) -> f64
-    %c2 = complex.create %c1, %c0 : complex<f64>
-    %0 = complex.pow %arg0, %c2 : complex<f64>
+    %0 = complex.powi %arg0, %arg1 : complex<f64>, i64
     return %0 : complex<f64>
   }
 
@@ -44,18 +32,12 @@ module {
   }
 
   func.func @pow_c16_i4(%arg0: complex<f128>, %arg1: i32) -> complex<f128> {
-    %c0 = arith.constant 0.000000e+00 : f128
-    %c1 = fir.convert %arg1 : (i32) -> f128
-    %c2 = complex.create %c1, %c0 : complex<f128>
-    %0 = complex.pow %arg0, %c2 : complex<f128>
+    %0 = complex.powi %arg0, %arg1 : complex<f128>, i32
     return %0 : complex<f128>
   }
 
   func.func @pow_c16_i8(%arg0: complex<f128>, %arg1: i64) -> complex<f128> {
-    %c0 = arith.constant 0.000000e+00 : f128
-    %c1 = fir.convert %arg1 : (i64) -> f128
-    %c2 = complex.create %c1, %c0 : complex<f128>
-    %0 = complex.pow %arg0, %c2 : complex<f128>
+    %0 = complex.powi %arg0, %arg1 : complex<f128>, i64
     return %0 : complex<f128>
   }
 
@@ -67,11 +49,11 @@ module {
 
 // CHECK-LABEL: func.func @pow_c4_i4(
 // CHECK: fir.call @_FortranAcpowi(%{{.*}}, %{{.*}}) : (complex<f32>, i32) -> complex<f32>
-// CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
 
 // CHECK-LABEL: func.func @pow_c4_i8(
 // CHECK: fir.call @_FortranAcpowk(%{{.*}}, %{{.*}}) : (complex<f32>, i64) -> complex<f32>
-// CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
 
 // CHECK-LABEL: func.func @pow_c4_c4(
 // CHECK: fir.call @cpowf(%{{.*}}, %{{.*}}) : (complex<f32>, complex<f32>) -> complex<f32>
@@ -79,11 +61,11 @@ module {
 
 // CHECK-LABEL: func.func @pow_c8_i4(
 // CHECK: fir.call @_FortranAzpowi(%{{.*}}, %{{.*}}) : (complex<f64>, i32) -> complex<f64>
-// CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
 
 // CHECK-LABEL: func.func @pow_c8_i8(
 // CHECK: fir.call @_FortranAzpowk(%{{.*}}, %{{.*}}) : (complex<f64>, i64) -> complex<f64>
-// CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
 
 // CHECK-LABEL: func.func @pow_c8_c8(
 // CHECK: fir.call @cpow(%{{.*}}, %{{.*}}) : (complex<f64>, complex<f64>) -> complex<f64>
@@ -91,11 +73,11 @@ module {
 
 // CHECK-LABEL: func.func @pow_c16_i4(
 // CHECK: fir.call @_FortranAcqpowi(%{{.*}}, %{{.*}}) : (complex<f128>, i32) -> complex<f128>
-// CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
 
 // CHECK-LABEL: func.func @pow_c16_i8(
 // CHECK: fir.call @_FortranAcqpowk(%{{.*}}, %{{.*}}) : (complex<f128>, i64) -> complex<f128>
-// CHECK-NOT: complex.pow
+// CHECK-NOT: complex.powi
 
 // CHECK-LABEL: func.func @pow_c16_c16(
 // CHECK: fir.call @_FortranACPowF128(%{{.*}}, %{{.*}}) : (complex<f128>, complex<f128>) -> complex<f128>
diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
index 44590406301eb..ca5103c16889c 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -443,6 +443,32 @@ def PowOp : ComplexArithmeticOp<"pow"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// PowiOp
+//===----------------------------------------------------------------------===//
+
+def PowiOp : Complex_Op<"powi",
+    [Pure, Elementwise, SameOperandsAndResultShape,
+     AllTypesMatch<["lhs", "result"]>]> {
+  let summary = "complex number raised to integer power";
+  let description = [{
+    The `powi` operation takes a complex number and an integer exponent.
+
+    Example:
+
+    ```mlir
+    %a = complex.powi %b, %c : complex<f32>, i32
+    ```
+  }];
+
+  let arguments = (ins Complex<AnyFloat>:$lhs,
+                       AnySignlessInteger:$rhs);
+  let results = (outs Complex<AnyFloat>:$result);
+
+  let assemblyFormat =
+      "$lhs `,` $rhs attr-dict `:` type($result) `,` type($rhs)";
+}
+
 //===----------------------------------------------------------------------===//
 // ReOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
index 0372f32d6b6df..25e5ab49cdb8c 100644
--- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -7,9 +7,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
@@ -71,10 +73,40 @@ struct PowOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowOp> {
     return success();
   }
 };
+
+// Rewrite complex.powi(z, n) -> complex.pow(z, complex(float(n), 0))
+struct PowiOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowiOp> {
+  using OpRewritePattern<complex::PowiOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(complex::PowiOp op,
+                                PatternRewriter &rewriter) const final {
+    auto complexType = cast<ComplexType>(getElementTypeOrSelf(op.getType()));
+    Type elementType = complexType.getElementType();
+
+    Type exponentType = op.getRhs().getType();
+    Type exponentFloatType = elementType;
+    if (auto shapedType = dyn_cast<ShapedType>(exponentType))
+      exponentFloatType = shapedType.cloneWith(std::nullopt, elementType);
+
+    Location loc = op.getLoc();
+    Value exponentReal =
+        rewriter.create<arith::SIToFPOp>(loc, exponentFloatType, op.getRhs());
+    Value zeroImag = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(exponentFloatType));
+    Value exponent = rewriter.create<complex::CreateOp>(
+        loc, op.getLhs().getType(), exponentReal, zeroImag);
+
+    rewriter
+        .replaceOpWithNewOp<complex::PowOp>(op, op.getType(), op.getLhs(),
+                                            exponent);
+    return success();
+  }
+};
 } // namespace
 
 void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
     RewritePatternSet &patterns) {
+  patterns.add<PowiOpToROCDLLibraryCalls>(patterns.getContext());
   patterns.add<PowOpToROCDLLibraryCalls>(patterns.getContext());
   patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float32Type>>(
       patterns.getContext(), "__ocml_cabs_f32");
@@ -125,11 +157,12 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {
   populateComplexToROCDLLibraryCallsConversionPatterns(patterns);
 
   ConversionTarget target(getContext());
-  target.addLegalDialect<func::FuncDialect>();
-  target.addLegalOp<complex::MulOp>();
+  target.addLegalDialect<arith::ArithDialect, func::FuncDialect>();
+  target.addLegalOp<complex::CreateOp, complex::MulOp>();
   target.addIllegalOp<complex::AbsOp, complex::CosOp, complex::ExpOp,
-                      complex::LogOp, complex::PowOp, complex::SinOp,
-                      complex::SqrtOp, complex::TanOp, complex::TanhOp>();
+                      complex::LogOp, complex::PowOp, complex::PowiOp,
+                      complex::SinOp, complex::SqrtOp, complex::TanOp,
+                      complex::TanhOp>();
   if (failed(applyPartialConversion(op, target, std::move(patterns))))
     signalPassFailure();
 }
diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
index 31785eb20a642..3711c112cc631 100644
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -13,6 +13,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Math/Transforms/Passes.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -175,12 +176,20 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
 
   Value one;
   Type opType = getElementTypeOrSelf(op.getType());
-  if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>)
+  if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>) {
     one = arith::ConstantOp::create(rewriter, loc,
                                     rewriter.getFloatAttr(opType, 1.0));
-  else
+  } else if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>) {
+    ...
[truncated]

Contributor

@Copilot Copilot AI left a comment

Pull Request Overview

This PR adds a new complex.powi operation to MLIR's complex dialect for computing complex numbers raised to integer powers. The implementation provides more efficient handling of complex-to-integer power operations compared to the generic complex power operation.

Key changes include:

  • Addition of the new PowiOp operation definition in the Complex dialect
  • Integration with algebraic simplification passes for optimization
  • Support for conversion to ROCDL library calls
  • Updates to Flang frontend to generate the new operation

Reviewed Changes

Copilot reviewed 14 out of 14 changed files in this pull request and generated 2 comments.

| File | Description |
|---|---|
| mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td | Defines the new PowiOp operation with complex base and integer exponent |
| mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp | Adds algebraic simplification patterns for complex.powi operations |
| mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp | Implements conversion from complex.powi to ROCDL library calls |
| flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp | Updates complex power conversion to handle both powi and pow operations |
| flang/lib/Optimizer/Builder/IntrinsicCall.cpp | Modifies intrinsic call generation to use powi for integer exponents |
| Multiple test files | Updates test expectations to reflect the new powi operation usage |

auto imagPart = rewriter.getFloatAttr(elementType, 0.0);
one = rewriter.create<complex::ConstantOp>(
loc, complexTy, rewriter.getArrayAttr({realPart, imagPart}));
} else {
Copilot AI Sep 15, 2025

[nitpick] The nested if-else chain with constexpr conditions could be simplified using if constexpr for all branches to improve readability and consistency.

Suggested change
} else {
} else if constexpr (true) {


Comment on lines +109 to 112
patterns.add<PowiOpToROCDLLibraryCalls>(patterns.getContext());
patterns.add<PowOpToROCDLLibraryCalls>(patterns.getContext());
Copilot AI Sep 15, 2025

[nitpick] Consider reordering these pattern additions to maintain alphabetical order for better code organization and maintainability.



github-actions bot commented Sep 15, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Contributor

@vzakhari vzakhari left a comment

LGTM. Thank you!

Just a couple of minor comments.

@joker-eph
Collaborator

What is the LLVM lowering here?

AllTypesMatch<["lhs", "result"]>]> {
let summary = "complex number raised to integer power";
let description = [{
The `powi` operation takes a complex number and an integer exponent.
Collaborator

@joker-eph joker-eph Sep 15, 2025

The signedness of the operands, as well as the overflow (or other special) behavior, should be specified here.

Member Author

I have updated the description to be consistent with other similar Powi ops. Let me know if I'm missing anything.
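
A small illustration of why the signedness question matters for a signless integer exponent: in the diff above, the ROCDL pattern converts the exponent with `arith::SIToFPOp`, which reads the bits as a signed value. The sketch below is plain C++, not MLIR, and the helper names `asSignedExponent` and `asUnsignedExponent` are hypothetical; it only shows that one 32-bit pattern names two very different exponents depending on the reading.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Hypothetical helper: read a 32-bit pattern as a signed exponent, which is
// the interpretation a signed int-to-float conversion would apply.
std::int32_t asSignedExponent(std::uint32_t bits) {
  std::int32_t value;
  std::memcpy(&value, &bits, sizeof(value)); // bit-for-bit reinterpretation
  return value;
}

// Hypothetical helper: read the same bits as an unsigned exponent.
std::uint32_t asUnsignedExponent(std::uint32_t bits) { return bits; }
```

Under the signed reading, the all-ones pattern is the exponent -1 (a reciprocal); under the unsigned reading it would be a very large positive power, so the op description has to say which one applies.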

@vzakhari
Contributor

> What is the LLVM lowering here?

Good point. We can try to reuse the conversion for math::FPowIOp: https://github.com/llvm/llvm-project/blob/main/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp#L379

Contributor

@tblah tblah left a comment

LGTM once the existing comments are addressed.

@TIFitis TIFitis force-pushed the users/Akash/complex_powi branch from 45a1331 to 6faad70 Compare September 17, 2025 21:17
@TIFitis
Member Author

TIFitis commented Sep 17, 2025

> > What is the LLVM lowering here?
>
> Good point. We can try to reuse the conversion for math::FPowIOp: https://github.com/llvm/llvm-project/blob/main/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp#L379

@joker-eph @vzakhari At the moment both ConvertComplexPow and ConvertComplexToROCDL convert complex.powi to some library calls and no further lowering is needed. Is it okay if we convert complex.powi to complex.pow as a fallback and let that take care of LLVM lowering?

@joker-eph
Collaborator

> ConvertComplexPow

That isn't in MLIR right now, so that's not generally usable.

> Is it okay if we convert complex.powi to complex.pow as a fallback and let that take care of LLVM lowering?

I'm confused: are these the same op? I would assume the semantics differ, or you wouldn't add a new op. So how come you can just convert one to the other?

@TIFitis
Member Author

TIFitis commented Sep 17, 2025

> I'm confused: are these the same op? I would assume the semantics differ, or you wouldn't add a new op. So how come you can just convert one to the other?

You can always convert powi to pow, by casting the integer exponent to a complex exponent with no imaginary part.

The introduction of powi is only for enabling algebraic simplifications that we can perform on an integer only exponent similar to powi ops in other dialects.
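
As a rough illustration of the two points above (not the code in this PR), the sketch below uses plain `std::complex` rather than MLIR. `powiViaPow` mirrors the fallback of casting the integer exponent to a complex exponent with a zero imaginary part, and `powiBySquaring` shows the kind of multiply-only rewrite that an integer exponent makes possible. All function names here are hypothetical, and the tolerance in `closeEnough` is an arbitrary choice.

```cpp
#include <cassert>
#include <cmath>
#include <complex>
#include <cstdint>

using Complex = std::complex<double>;

// Fallback form: treat the integer exponent as a complex number with a zero
// imaginary part and call the general complex power.
Complex powiViaPow(Complex z, std::int64_t n) {
  return std::pow(z, Complex(static_cast<double>(n), 0.0));
}

// Integer-only form: exponentiation by squaring, using only multiplies and at
// most one divide (for negative exponents).
Complex powiBySquaring(Complex z, std::int64_t n) {
  // Take the magnitude in unsigned arithmetic so the most negative value of
  // n is handled without signed overflow.
  std::uint64_t k = n < 0 ? 0 - static_cast<std::uint64_t>(n)
                          : static_cast<std::uint64_t>(n);
  Complex result(1.0, 0.0);
  Complex base = z;
  while (k != 0) {
    if (k & 1)
      result *= base; // this bit of the exponent contributes a factor
    k >>= 1;
    if (k != 0)
      base *= base;   // square the base for the next bit
  }
  return n < 0 ? Complex(1.0, 0.0) / result : result;
}

// Relative comparison helper for checking the two forms against each other.
bool closeEnough(Complex a, Complex b, double tol = 1e-9) {
  return std::abs(a - b) <= tol * (1.0 + std::abs(b));
}
```

For example, both forms evaluate (1 + 2i)^3 to approximately -11 - 2i, but the squaring form needs no `log`/`exp` evaluation, which is why keeping the exponent as an integer is useful for simplification passes.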

@TIFitis
Member Author

TIFitis commented Sep 17, 2025

> That isn't in MLIR right now, so that's not generally usable.

I've added complex.powi -> complex.pow conversion to the ComplexToStandard MLIR pass.

@joker-eph
Collaborator

> > That isn't in MLIR right now, so that's not generally usable.
>
> I've added complex.powi -> complex.pow conversion to the ComplexToStandard MLIR pass.

Thanks, LG!

});

auto call = fir::CallOp::create(builder, loc, callee, args);
mod.walk([&](complex::PowOp op) {
Collaborator

We should not walk multiple times if we can do it in a single traversal. Can you replace this with a walk on `Operation *` and dispatch inside the walk?

Member Author

I've updated this.
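
The single-traversal idea can be sketched with a toy model. This is not MLIR code: `Op`, `PowOp`, `PowiOp`, `OtherOp`, `Counts`, `walkOnce`, and `makeSampleOps` are all hypothetical stand-ins, and `dynamic_cast` plays the role that `dyn_cast` on an `Operation *` would typically play inside an MLIR walk callback.

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Toy stand-ins for IR operations; these are NOT MLIR classes.
struct Op {
  virtual ~Op() = default;
};
struct PowOp : Op {};
struct PowiOp : Op {};
struct OtherOp : Op {};

// Tallies collected during the traversal.
struct Counts {
  int visited = 0;
  int pow = 0;
  int powi = 0;
};

// One pass over all operations, dispatching on the dynamic type inside the
// loop body instead of running one full traversal per operation kind.
Counts walkOnce(const std::vector<std::unique_ptr<Op>> &ops) {
  Counts counts;
  for (const auto &op : ops) {
    ++counts.visited;
    if (dynamic_cast<const PowiOp *>(op.get()))
      ++counts.powi; // integer-exponent case would be handled here
    else if (dynamic_cast<const PowOp *>(op.get()))
      ++counts.pow;  // complex-exponent case would be handled here
  }
  return counts;
}

// Builds a small sample "module": two powi ops, one pow op, one unrelated op.
std::vector<std::unique_ptr<Op>> makeSampleOps() {
  std::vector<std::unique_ptr<Op>> ops;
  ops.push_back(std::make_unique<PowiOp>());
  ops.push_back(std::make_unique<OtherOp>());
  ops.push_back(std::make_unique<PowOp>());
  ops.push_back(std::make_unique<PowiOp>());
  return ops;
}
```

Each operation is visited exactly once regardless of how many kinds are handled, which is the cost argument behind the review request.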

Contributor

@vzakhari vzakhari left a comment

The powi part looks good to me. Are you planning to merge it, and then rebase the other PR for the Flang changes for the final review?

@TIFitis
Member Author

TIFitis commented Sep 18, 2025

> The powi part looks good to me. Are you planning to merge it, and then rebase the other PR for the Flang changes for the final review?

I plan on landing both PRs at once. This PR depends on #158642, which should land first.
All the work should have been in a single PR but I split it up to make it easier to review.

Contributor

@vzakhari vzakhari left a comment

LGTM with some final comments.

Comment on lines 1275 to 1284
if constexpr (std::is_same_v<T, mlir::complex::PowOp>) {
auto resultType = mathLibFuncType.getResult(0);
result = T::create(builder, loc, resultType, args);
} else if constexpr (std::is_same_v<T, mlir::complex::PowiOp>) {
auto resultType = mathLibFuncType.getResult(0);
auto fmfAttr = mlir::arith::FastMathFlagsAttr::get(
builder.getContext(), builder.getFastMathFlags());
result = builder.create<mlir::complex::PowiOp>(loc, resultType, args[0],
args[1], fmfAttr);
} else {
Contributor

Do we really need all this code? I believe just a simple `T::create(builder, loc, args)` should work, because of the type constraints in the operation definitions.

Member Author

You're right, I've simplified it. Thanks for catching.

Type elementType = complexTy.getElementType();
auto realPart = rewriter.getFloatAttr(elementType, 1.0);
auto imagPart = rewriter.getFloatAttr(elementType, 0.0);
one = rewriter.create<complex::ConstantOp>(
Contributor

I believe all the create methods of the rewriter will become deprecated soon, so complex::ConstantOp::create is a better alternative. There are other cases below.

Member Author

Done.

TIFitis added a commit that referenced this pull request Sep 19, 2025
This PR introduces a new `ConvertComplexPow` pass for Flang that handles
complex power operations. The change forces lowering to complex.pow
operations when `--math-runtime=precise` is not used, then uses the
`ConvertComplexPow` pass to convert these operations back to library
calls.

- Adds a new `ConvertComplexPow` pass that converts complex.pow ops to
appropriate runtime library calls
- Updates complex power lowering to use `complex.pow` operations by
default instead of direct library calls

#158722 Adds a new `complex.powi` op enabling algebraic optimisations.
Base automatically changed from users/Akash/complex_lowering to main September 19, 2025 00:51
@TIFitis TIFitis force-pushed the users/Akash/complex_powi branch from 6dbb370 to 78d9190 Compare September 19, 2025 01:11
@TIFitis TIFitis enabled auto-merge (squash) September 19, 2025 01:12
@TIFitis TIFitis merged commit fdb1f48 into main Sep 19, 2025
9 checks passed
@TIFitis TIFitis deleted the users/Akash/complex_powi branch September 19, 2025 01:36